import torch
from useful_tools import create_module
from .ffn import (
    NaiveFFN,
    StarFFN,
)

class create_NaiveFFN(create_module):
    """Thin ``create_module`` subclass bound to the ``NaiveFFN`` class.

    The constructor hands the ``NaiveFFN`` class (imported from ``.ffn``)
    to ``create_module.__init__``; everything else is delegated to
    ``create_module``, which is defined in ``useful_tools``.
    """

    def __init__(
        self,
        input_dim, 
        hidden_dim=None, 
        output_dim=None, 
        dropout=0.1, 
        precision=torch.float32, 
        require_bias=True, 
        activate=None
        ):
        """Initialise the wrapper with ``NaiveFFN`` as its target class.

        Args:
            input_dim: Input feature size (presumably meant for
                ``NaiveFFN`` — TODO confirm).
            hidden_dim: Hidden feature size; defaults to ``None``. How
                ``None`` is resolved is not decided here — confirm in
                ``NaiveFFN``.
            output_dim: Output feature size; defaults to ``None``. How
                ``None`` is resolved is not decided here — confirm in
                ``NaiveFFN``.
            dropout: Dropout setting; defaults to ``0.1``.
            precision: Tensor dtype; defaults to ``torch.float32``.
            require_bias: Bias flag; defaults to ``True``.
            activate: Activation choice; defaults to ``None``.

        NOTE(review): none of the arguments above are referenced in this
        body, and only ``NaiveFFN`` is passed to ``create_module.__init__``.
        Unless ``create_module`` recovers them itself (e.g. by inspecting
        the caller's frame locals or this signature), they are silently
        ignored. Verify against ``useful_tools.create_module`` before
        changing this method. Restructuring it, including switching to
        zero-argument ``super()``, changes the frame's locals.
        """
        super(create_NaiveFFN, self).__init__(NaiveFFN)
    
class create_StarFFN(create_module):
    """Thin ``create_module`` subclass bound to the ``StarFFN`` class.

    The constructor hands the ``StarFFN`` class (imported from ``.ffn``)
    to ``create_module.__init__``; everything else is delegated to
    ``create_module``, which is defined in ``useful_tools``.
    """

    def __init__(
        self,
        input_dim, 
        hidden_dim=None, 
        output_dim=None, 
        dropout=0.1, 
        precision=torch.float32, 
        require_bias=True, 
        activate=None
        ):
        """Initialise the wrapper with ``StarFFN`` as its target class.

        Args:
            input_dim: Input feature size (presumably meant for
                ``StarFFN`` — TODO confirm).
            hidden_dim: Hidden feature size; defaults to ``None``. How
                ``None`` is resolved is not decided here — confirm in
                ``StarFFN``.
            output_dim: Output feature size; defaults to ``None``. How
                ``None`` is resolved is not decided here — confirm in
                ``StarFFN``.
            dropout: Dropout setting; defaults to ``0.1``.
            precision: Tensor dtype; defaults to ``torch.float32``.
            require_bias: Bias flag; defaults to ``True``.
            activate: Activation choice; defaults to ``None``.

        NOTE(review): none of the arguments above are referenced in this
        body, and only ``StarFFN`` is passed to ``create_module.__init__``.
        Unless ``create_module`` recovers them itself (e.g. by inspecting
        the caller's frame locals or this signature), they are silently
        ignored. Verify against ``useful_tools.create_module`` before
        changing this method. Restructuring it, including switching to
        zero-argument ``super()``, changes the frame's locals.
        """
        super(create_StarFFN, self).__init__(StarFFN)